""" explain_rcexplainer.py

    Implementation of the RCExplainer method without LDBs (linear decision
    boundaries), as described in "Robust Counterfactual Explanations on Graph
    Neural Networks".
"""

import math
import time
import os

import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

import random

import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import tensorboardX.utils

import torch
import torch.nn as nn
from torch.autograd import Variable

import sklearn.metrics as metrics
from sklearn.metrics import roc_auc_score, recall_score, precision_score, roc_auc_score, precision_recall_curve
from sklearn.cluster import DBSCAN

import pdb
import pickle

import utils.io_utils as io_utils
import utils.train_utils as train_utils
import utils.noise_utils as noise_utils

import dnn_invariant.extract_rules as extract
import utils.graph_utils as graph_utils
import utils.accuracy_utils3 as accuracy_utils
import utils.neighbor_utils as neighbor_utils
import explainer.explain as explain

from scipy.sparse import coo_matrix

# Tensor constructors bound to CUDA when a GPU is available, CPU otherwise.
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
Tensor = FloatTensor


# NOTE(review): not read in the visible part of this file; presumably used elsewhere.
nbr_data = None
rule_dict_node = None

# Reported in the evaluation log header written by eval_graphs.
use_comb_mask = True
# Noise statistics; eval_graphs declares avg_add_edges / avg_removed_edges /
# global_noise_count as globals and divides the first two by the last.
avg_add_edges = 0.
avg_removed_edges = 0.
global_noise_count = 0.
# NOTE(review): the two mask-density globals are not read in the visible code.
global_mask_dense_count = 0.
global_mask_density = 0.

# Loss-coefficient placeholders; ent_cf and size_cf are echoed into the eval log.
ent_cf = -1.0
size_cf = -1.0
lap_cf = -1.0

# NOTE(review): only referenced from commented-out code in eval_graphs.
sub_label_nodes = None
sub_label_array = None


class ExplainerRCExplainerNoLDB(explain.Explainer):
    """RCExplainer without linear decision boundaries (LDBs), for graph classification.

    Trains an ``ExplainModule`` that outputs an edge mask for each graph and
    evaluates those masks against the wrapped GNN. ``self.adj``, ``self.feat``,
    ``self.label``, ``self.pred``, ``self.num_nodes``, ``self.model``,
    ``self.args``, ``self.writer``, ``self.print_training`` and ``self.device``
    are expected to be set by ``explain.Explainer.__init__`` (not visible here).
    """

    def __init__(
        self,
        model,
        adj,
        feat,
        label,
        pred,
        emb,
        train_idx,
        args,
        writer=None,
        print_training=False,
        graph_mode=False,
        graph_idx=False,
        num_nodes=None,
        device='cpu'
    ):
        """Store the inputs via the base explainer and put the model in eval mode.

        ``emb`` holds precomputed node embeddings, indexed as
        ``emb[graph_idx, node_indices]`` by ``extract_neighborhood_emb``.
        """
        super().__init__(model, adj, feat, label, pred, train_idx, args, writer, print_training, graph_mode,
                         graph_idx, num_nodes, device)

        # Temperature schedule endpoints. NOTE(review): not read within this class.
        self.coeffs = {
            "t0": 0.5,
            "t1": 4.99,
        }

        self.emb = emb
        self.model.eval()

    def extract_neighborhood_emb(self, node_idx, graph_idx=0):
        """Like ``extract_neighborhood`` but also return the neighbors' rows of ``self.emb``."""
        node_idx_new, sub_adj, sub_feat, sub_label, neighbors = super().extract_neighborhood(node_idx, graph_idx)
        sub_embs = self.emb[graph_idx, neighbors]
        return node_idx_new, sub_adj, sub_feat, sub_label, neighbors, sub_embs

    def get_nbr_data(self, args, node_indices, graph_idx=0):
        """Delegate to the base explainer's ``get_nbr_data``."""
        return super().get_nbr_data(args, node_indices, graph_idx)

    def _prepare_graph(self, graph_idx):
        """Build the per-graph inputs shared by training and evaluation.

        Returns a tuple ``(sub_adj, sub_feat, sub_nodes, adj, x, label, emb)``:
        ``sub_adj``/``sub_feat`` are numpy arrays with a leading axis of size 1,
        ``adj``/``x`` are float tensors built from them (``x`` requires grad),
        ``sub_nodes`` is ``self.num_nodes[graph_idx]`` (or None), ``label`` is a
        detached copy of ``self.label[graph_idx]`` and ``emb`` is the detached
        output of ``self.model.getEmbeddings``.
        """
        if len(self.adj.shape) < 3:
            sub_adj = self.adj
        else:
            sub_adj = self.adj[graph_idx]
        sub_feat = self.feat[graph_idx, :]
        sub_label = self.label[graph_idx]
        sub_nodes = self.num_nodes[graph_idx] if self.num_nodes is not None else None

        sub_adj = np.expand_dims(sub_adj, axis=0)
        sub_feat = np.expand_dims(sub_feat, axis=0)

        adj = torch.tensor(sub_adj, dtype=torch.float)
        x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
        label = sub_label.clone().detach()

        emb = self.model.getEmbeddings(x, adj, batch_num_nodes=[sub_nodes.cpu().numpy()])
        emb = emb.clone().detach()
        return sub_adj, sub_feat, sub_nodes, adj, x, label, emb

    @staticmethod
    def _params_sum(module):
        """Return the sum of all parameters of ``module`` (a cheap checkpoint fingerprint)."""
        return sum((torch.sum(p).item() for p in module.parameters()), 0.0)

    # GRAPH EXPLAINER
    def eval_graphs(self, args, graph_indices, explainer):
        """Evaluate a trained explainer on ``graph_indices`` and log the results.

        For every graph the explainer's edge mask is computed under
        ``torch.no_grad``. The top-k edges of the mask are re-fed to the model,
        and per-predicted-class statistics are accumulated. A summary is
        printed and appended to ``eval_<bmname>_<explainer_method>.txt`` in the
        directory of ``self.args.exp_path``.

        Args:
            args: parsed options; ``args.draw_graphs`` is read from here, most
                other options come from ``self.args``.
            graph_indices: iterable of graph indices to evaluate.
            explainer: an ``ExplainModule``; it is put into eval mode here.

        Side effects:
            Rebinds the module globals ``global_noise_count`` and
            ``noise_diff_count``. With ``args.draw_graphs`` only 5 random graphs
            are evaluated, and they are drawn via ``accuracy_utils.saveAndDrawGraph``.
        """
        global global_noise_count
        global noise_diff_count
        global avg_add_edges
        global avg_removed_edges

        # Ground-truth edges (h_edges) only exist for the filtered Mutagenicity set.
        use_gt = bool(self.args.apply_filter and self.args.bmname == 'Mutagenicity')
        h_edges = None
        if use_gt:
            graph_indices, h_edges = accuracy_utils.filterMutag2(graph_indices, self.label, self.feat, self.adj,
                                                                 self.num_nodes)
        graph_indices = list(graph_indices)

        if args.draw_graphs:
            random.shuffle(graph_indices)
            graph_indices = graph_indices[:5]
        np.random.shuffle(graph_indices)

        self.model.eval()
        explainer.eval()

        num_classes = self.pred[0][0].shape[0]
        # Accumulators indexed by the model's predicted class.
        flips = np.zeros(num_classes)
        inv_flips = np.zeros(num_classes)
        topk_inv_flips = np.zeros(num_classes)
        pos_diff = np.zeros(num_classes)
        inv_diff = np.zeros(num_classes)
        topk_inv_diff = np.zeros(num_classes)
        total = np.zeros(num_classes)

        incorrect_preds = 0.
        ep_variance = 0.
        skipped_iters = 0.
        avg_noise_diff = 0.
        avg_adj_diff = 0.
        acc_count = 0.
        avg_mask_density = 0.
        avg_pred_diff = 0.
        pred_removed_edges = 0.
        topk = self.args.topk
        AUC = accuracy_utils.AUC()

        noise_iters = 1
        noise_range = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
        # Noise handlers are only built when noise evaluation is requested; previously they
        # were referenced without ever being defined (NameError when args.noise was set).
        noise_handlers = []
        if self.args.noise:
            noise_handlers = [noise_utils.NoiseHandler("PGExplainer", self.model, self, noise_percent=x)
                              for x in noise_range]

        explainer_sum = self._params_sum(explainer)
        model_sum = self._params_sum(self.model)
        print("sum of params of loaded explainer: {}".format(explainer_sum))
        print("sum of params of loaded model: {}".format(model_sum))

        stats = accuracy_utils.Stats("PGExplainer_Boundary", explainer, self.model)

        # Constant evaluation temperature (the former t0 == t1 == 5.0 schedule).
        tmp = 5.0

        for graph_idx in graph_indices:
            with torch.no_grad():
                sub_adj, sub_feat, sub_nodes, adj, x, label, emb = self._prepare_graph(graph_idx)
                batch_num_nodes = [sub_nodes.cpu().numpy()]
                ht_edges = h_edges[graph_idx] if use_gt else None

                pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
                if pred_label != label.item():
                    incorrect_preds += 1

                pred, masked_adj, _, _, inv_pred = explainer((x[0], emb[0], adj[0], tmp, label, sub_nodes),
                                                             training=False)

                # Same flip criteria as the training loop in explain_graphs.
                total[pred_label] += 1.0
                if torch.argmax(pred[0]) != pred_label:
                    flips[pred_label] += 1.0
                if torch.argmax(inv_pred[0]) == pred_label:
                    inv_flips[pred_label] += 1.0

                pred_t = torch.from_numpy(self.pred[0][graph_idx]).float().cuda()
                pred_t = nn.Softmax(dim=0)(pred_t)

                pos_diff[pred_label] += (pred_t[pred_label] - pred[0][pred_label]).item()
                inv_diff[pred_label] += (pred_t[pred_label] - inv_pred[0][pred_label]).item()

                # Restrict the mask to edges that exist in the input graph.
                masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()

                # Re-run the model on only the top-k masked edges.
                topk_adj = noise_utils.filterTopK(masked_adj, sub_adj[0], k=topk)
                topk_adj_t = torch.from_numpy(topk_adj).float().cuda()
                pred_topk, _ = self.model(x.cuda(), topk_adj_t.unsqueeze(0), batch_num_nodes=batch_num_nodes)
                pred_topk = nn.Softmax(dim=0)(pred_topk[0])

                topk_inv_diff[pred_label] += (pred_t[pred_label] - pred_topk[pred_label]).item()
                if torch.argmax(pred_topk) == pred_label:
                    topk_inv_flips[pred_label] += 1.0

                if self.args.post_processing:
                    masked_adj = accuracy_utils.getModifiedMask(masked_adj, sub_adj[0], sub_nodes.cpu().numpy())

                if use_gt:
                    if self.args.post_processing:
                        # NOTE(review): getModifiedMask is applied a second time here, as in the
                        # original code; confirm this is intended and not a copy-paste duplicate.
                        masked_adj = accuracy_utils.getModifiedMask(masked_adj, sub_adj[0],
                                                                    sub_nodes.cpu().numpy())
                    AUC.addEdges(masked_adj, ht_edges)

                variance = np.sum(np.abs(masked_adj - 0.5) * sub_adj.squeeze()) / np.sum(sub_adj)
                ep_variance += variance

                imp_nodes = explain.getImportantNodes(masked_adj, 8)
                stats.update(masked_adj, imp_nodes, adj, x, None, sub_nodes)

                if args.draw_graphs:
                    num_nodes_i = self.num_nodes[graph_idx].item()
                    if use_gt:
                        # Ground-truth edges drawn on top of the (down-weighted) input graph.
                        gt_mask = sub_adj[0] - 0.9
                        for e in ht_edges.keys():
                            gt_mask[e[0], e[1]] = 1.0
                            gt_mask[e[1], e[0]] = 1.0
                        accuracy_utils.saveAndDrawGraph(gt_mask, sub_adj[0], sub_feat[0], num_nodes_i, self.args,
                                                        label.item(), pred_label, graph_idx,
                                                        prob=pred_t[pred_label],
                                                        plt_path=None, adj_mask_bool=True, prefix="gt_")

                    accuracy_utils.saveAndDrawGraph(masked_adj, sub_adj[0], sub_feat[0], num_nodes_i, self.args,
                                                    label.item(), pred_label, graph_idx,
                                                    prob=pred_t[pred_label],
                                                    plt_path=None, adj_mask_bool=True)
                    accuracy_utils.saveAndDrawGraph(None, topk_adj, sub_feat[0], num_nodes_i, self.args,
                                                    label.item(), pred_label, graph_idx, prob=pred_topk[pred_label],
                                                    plt_path=None, adj_mask_bool=False)

                if self.args.inverse_noise:
                    adj_np = adj[0].cpu().numpy()
                    new_adj, added_edges, removed_edges = noise_utils.addNoiseToGraphInverse(masked_adj, adj_np,
                                                                                             self.args.noise_percent)
                    pred_removed_edges += removed_edges

                    new_adj_t = torch.from_numpy(new_adj).float().cuda()
                    pred_masked, _ = self.model(x.cuda(), new_adj_t.unsqueeze(0), batch_num_nodes=batch_num_nodes)
                    pred_masked_sfmx = torch.nn.functional.softmax(pred_masked[0], dim=0).detach().cpu().numpy()
                    pred_sfmx = torch.nn.functional.softmax(
                        torch.from_numpy(self.pred[0][graph_idx]).float(), dim=0).cpu().numpy()
                    avg_pred_diff += np.abs(pred_sfmx[pred_label] - pred_masked_sfmx[pred_label])

                acc_count += 1.0
                avg_mask_density += np.sum(masked_adj) / np.sum(adj.cpu().numpy())
                label = self.label[graph_idx]

                if self.args.noise:
                    for _ in range(noise_iters):
                        for nh in noise_handlers:
                            try:
                                noise_feat, noise_adj = nh.sample(sub_feat[0], sub_adj[0], sub_nodes)
                            except Exception:
                                # Best effort: skip handlers that cannot sample this graph.
                                continue

                            emb_noise = self.model.getEmbeddings(noise_feat.unsqueeze(0), noise_adj.unsqueeze(0),
                                                                 [sub_nodes.cpu().numpy()])
                            _, masked_adj_n, _, _, _ = explainer(
                                (noise_feat, emb_noise[0], noise_adj, tmp, label, sub_nodes), training=False)
                            masked_adj_n = masked_adj_n.cpu().detach() * noise_adj

                            nh.update(masked_adj, masked_adj_n.cpu().detach().numpy(), sub_adj[0],
                                      noise_adj.cpu().detach().numpy(), None, graph_idx)

        # Noise statistics are not accumulated here, so their divisors are fixed at 1.
        global_noise_count = 1.0
        noise_diff_count = 1.0

        # Avoid ZeroDivisionError when no graph was evaluated (all numerators are 0 then).
        denom = max(acc_count, 1.0)

        eval_dir = os.path.dirname(self.args.exp_path)
        eval_file = "eval_" + self.args.bmname + "_" + self.args.explainer_method + ".txt"
        eval_file = os.path.join(eval_dir, eval_file)

        with open(eval_file, "a") as myfile:
            myfile.write("\n \n \n {}".format(self.args.bmname))
            myfile.write("\n method: {}".format(self.args.explainer_method))
            myfile.write("\n bloss version: {}".format(self.args.bloss_version))
            myfile.write("\n ckpt dir: {}".format(self.args.ckptdir))
            myfile.write("\n exp_path: {}".format(self.args.exp_path))
            myfile.write("\n explainer params sum: {}, model params sum: {}".format(explainer_sum, model_sum))

            myfile.write("\n use comb: {},  size cf: {}, ent cf {}".format(use_comb_mask, size_cf, ent_cf))

            if use_gt:
                auc_score = AUC.getAUC()
                print("ROC AUC score: {}".format(auc_score))
                myfile.write("\n ROC AUC score: {}".format(auc_score))

            # Classes that were never predicted would otherwise divide by zero.
            total[total < 0.5] = 1.0

            print("noise percent: {}, inverse noise: {}".format(self.args.noise_percent, self.args.inverse_noise))
            myfile.write("\n noise percent: {}, inverse noise: {}".format(self.args.noise_percent,
                                                                          self.args.inverse_noise))

            print("avg removed edges: {}".format(avg_removed_edges / global_noise_count))
            myfile.write("\n avg removed edges: {}".format(avg_removed_edges / global_noise_count))

            print("pred removed edges: {}".format(pred_removed_edges / denom))
            myfile.write("\n pred removed edges: {}".format(pred_removed_edges / denom))

            print("avg added edges: {}".format(avg_add_edges / global_noise_count))
            myfile.write("\n avg added edges: {}".format(avg_add_edges / global_noise_count))
            print("avg adj diff: {}".format(avg_adj_diff / global_noise_count))
            print("avg noise diff: {}".format(avg_noise_diff / noise_diff_count))
            myfile.write("\n avg noise diff: {}".format(avg_noise_diff / noise_diff_count))

            print("avg pred diff: {}".format(avg_pred_diff / denom))
            myfile.write("\n avg pred diff: {}".format(avg_pred_diff / denom))

            print("skipped iters: {}".format(skipped_iters))
            myfile.write("\n skipped iters: {}".format(skipped_iters))

            print("Average mask density: {}".format(avg_mask_density / denom))
            myfile.write("\n Average mask density: {}".format(avg_mask_density / denom))

            print("pos diff: {}, inv diff: {}, k: {}, topk inv diff: {}".format(pos_diff / total, inv_diff / total,
                                                                                 topk, topk_inv_diff / total))
            myfile.write("\n pos diff: {}, inv diff: {}, topk inv diff: {}".format(pos_diff / total, inv_diff / total,
                                                                                   topk_inv_diff / total))

            print("Variance: ", ep_variance / denom)
            myfile.write("\n Variance: {}".format(ep_variance / denom))

            print("Flips: ", flips)
            print("inv Flips: ", inv_flips)
            print("topk inv Flips: ", topk_inv_flips)
            print("Incorrect preds: ", incorrect_preds)
            print("Total: ", total)

            myfile.write(
                "\n flips: {}, Inv flips: {}, topk: {}, topk Inv flips: {}, Incorrect preds: {}, Total: {}".format(
                    flips, inv_flips, self.args.topk, topk_inv_flips, incorrect_preds, total))

            print(stats)
            myfile.write(str(stats))
            if self.args.noise:
                for nh in noise_handlers:
                    print(nh)

                print("NOISE SUMMARY")
                for nh in noise_handlers:
                    print(nh.summary())

            print("SUMMARY")
            print(stats.summary())

    # GRAPH EXPLAINER
    def explain_graphs(self, args, graph_indices, test_graph_indices=None):
        """Train an explainer on ``graph_indices`` and optionally evaluate it.

        If ``self.args.exp_path`` is set, explainer weights (entries whose name
        does not contain ``"model"``) are restored from it first. With
        ``self.args.eval`` the explainer is then evaluated via ``eval_graphs``
        and the process exits. Otherwise the explainer is trained for epochs
        ``start_epoch .. num_epochs - 1``. Training logs to
        ``<ckptdir>/<prefix>_logdir`` and checkpoints every 10 epochs, plus a
        numbered checkpoint every 100 epochs. Only ``Mutagenicity`` is supported.

        Args:
            args: parsed options forwarded to rule extraction and ``eval_graphs``.
            graph_indices: graphs to train on.
            test_graph_indices: if given, evaluated with ``eval_graphs`` after training.

        Raises:
            ValueError: if ``self.args.bmname`` is not ``"Mutagenicity"``.
            SystemExit: after evaluation in eval mode, or when eval mode is
                requested without an explainer checkpoint.
        """
        graph_indices = list(graph_indices)

        explainer = ExplainModule(
            model=self.model,
            num_nodes=self.adj.shape[1],
            emb_dims=self.model.embedding_dim * self.model.num_layers * 2,  # TODO: fixme!
            device=self.device,
            args=self.args,
        )

        if self.args.eval or self.args.exp_path != "":
            if self.args.eval and self.args.exp_path == "":
                print("no explainer file to load")
                raise SystemExit
            print("loading initial explainer ckpt from : ", self.args.exp_path)

            # map_location lets a GPU-saved checkpoint load on the explainer's device.
            state_dict = torch.load(self.args.exp_path, map_location=self.device)
            exp_state_dict = explainer.state_dict()
            for name, param in state_dict.items():
                # Restore only explainer weights; the wrapped GNN keeps its current weights.
                if name in exp_state_dict and "model" not in name:
                    exp_state_dict[name].copy_(param)
            explainer.load_state_dict(exp_state_dict)

            if self.args.eval:
                self.eval_graphs(args, graph_indices, explainer)
                raise SystemExit

        if self.args.bmname != "Mutagenicity":
            raise ValueError(self.args.bmname + " not found!")
        if self.args.apply_filter:
            graph_indices = accuracy_utils.filterMutag(graph_indices, self.label)
            random.shuffle(graph_indices)
        size = 3000

        train_data = (self.adj[:size], self.feat[:size], self.label[:size], self.num_nodes[:size])
        val_data = (self.adj[size - 100:], self.feat[size - 100:], self.label[size - 100:], self.num_nodes[size - 100:])
        rule_dict = extract.extract_rules(self.args.bmname, train_data, val_data, args, self.model.state_dict(),
                                          graph_indices=None, pool_size=args.pool_size)

        # Optimize only the explainer's own parameters, never the GNN being explained.
        params_optim = [param for name, param in explainer.named_parameters() if "model" not in name]
        scheduler, optimizer = train_utils.build_optimizer(self.args, params_optim)

        log_path = os.path.join(self.args.ckptdir, self.args.prefix + "_logdir")
        if os.path.isdir(log_path):
            print("log dir already exists and will be overwritten")
            time.sleep(5)
        else:
            os.makedirs(log_path)

        # With gumbel sampling the explainer is called with training=False.
        training = not self.args.gumbel

        log_file = self.args.prefix + "log_rcexpnoldb_" + self.args.bmname + ".txt"
        log_file_path = os.path.join(log_path, log_file)
        with open(log_file_path, "a") as myfile:
            myfile.write("\n \n \n {}".format(self.args.bmname))
            myfile.write("\n method: {}".format(self.args.explainer_method))
            myfile.write("\n bloss version: {}, node mask: {}, apply filter: {}".format(
                self.args.bloss_version, self.args.node_mask, self.args.apply_filter))
            myfile.write("\n exp_path: {}".format(self.args.exp_path))
            myfile.write("\n opt: {}".format(self.args.opt_scheduler))
            myfile.write("\n gumbel: {}, training: {}".format(self.args.gumbel, training))
            myfile.write("\n lr: {}, bound cf: {}, size cf: {}, ent cf {}, inv cf {}".format(
                self.args.lr, self.args.boundary_c, self.args.size_c, self.args.ent_c, self.args.inverse_boundary_c))

        t0, t1 = 0.5, 4.99
        for epoch in range(self.args.start_epoch, self.args.num_epochs):
            loss = None
            loss_ep = 0.
            count = 0.
            avg_mask_density = 0.
            flips = 0.
            inv_flips = 0.
            pos_diff = 0.
            inv_diff = 0.
            topk_inv_diff = 0.
            topk_inv_flips = 0.
            incorrect_preds = 0.
            ep_variance = 0.

            stats = accuracy_utils.Stats("PGExplainer_Boundary", explainer, self.model)

            np.random.shuffle(graph_indices)
            explainer.train()

            # Temperature annealed geometrically from t0 toward t1 over the epochs.
            tmp = float(t0 * np.power(t1 / t0, epoch / self.args.num_epochs))

            for graph_idx in graph_indices:
                sub_adj, sub_feat, sub_nodes, adj, x, label, emb = self._prepare_graph(graph_idx)

                pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
                if pred_label != label.item():
                    incorrect_preds += 1

                # NOTE(review): the rule boundaries are not consumed by the NoLDB loss below.
                rule_ix = rule_dict['idx2rule'][graph_idx]
                rule = rule_dict['rules'][rule_ix]
                boundary_list = []
                for b_num in range(len(rule['boundary'])):
                    boundary = torch.from_numpy(rule['boundary'][b_num]['basis'])
                    if self.args.gpu:
                        boundary = boundary.cuda()
                    boundary_list.append(boundary)

                pred, masked_adj, _, _, inv_pred = explainer((x[0], emb[0], adj[0], tmp, label, sub_nodes),
                                                             training=training)

                loss, _ = explainer.loss(pred, pred_label, inv_pred=inv_pred)

                # Update the explainer; previously the loss was never back-propagated, so the
                # optimizer never stepped and the saved checkpoints were untrained.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_ep += loss.item()

                pred_t = torch.from_numpy(self.pred[0][graph_idx]).float().cuda()
                pred_t = nn.Softmax(dim=0)(pred_t)

                if self.args.boundary_c > 0.0:
                    if torch.argmax(pred[0]) != pred_label:
                        flips += 1.0
                    pos_diff += (pred_t[pred_label] - pred[0][pred_label]).item()

                if self.args.inverse_boundary_c > 0.0:
                    if torch.argmax(inv_pred[0]) == pred_label:
                        inv_flips += 1.0
                    inv_diff += (pred_t[pred_label] - inv_pred[0][pred_label]).item()

                # Restrict the mask to edges that exist in the input graph.
                masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()

                imp_nodes = explain.getImportantNodes(masked_adj, 8)
                stats.update(masked_adj, imp_nodes, adj, x, None, sub_nodes)

                variance = np.sum(np.abs(masked_adj - 0.5) * sub_adj.squeeze()) / np.sum(sub_adj)
                ep_variance += variance

                if epoch % 10 == 0:
                    topk_adj = noise_utils.filterTopK(masked_adj, sub_adj[0], k=self.args.topk)
                    topk_adj_t = torch.from_numpy(topk_adj).float().cuda()
                    pred_topk, _ = self.model(x.cuda(), topk_adj_t.unsqueeze(0),
                                              batch_num_nodes=[sub_nodes.cpu().numpy()])
                    pred_topk = nn.Softmax(dim=0)(pred_topk[0])

                    topk_inv_diff += (pred_t[pred_label] - pred_topk[pred_label]).item()
                    if torch.argmax(pred_topk) == pred_label:
                        topk_inv_flips += 1.0

                count += 1.0
                avg_mask_density += np.sum(masked_adj) / np.sum(adj.cpu().numpy())

            if scheduler is not None:
                scheduler.step()

            # `loss` is the last graph's loss; NaN if the epoch saw no graphs.
            last_loss = loss.item() if loss is not None else float("nan")
            n = max(count, 1.0)

            if self.print_training:
                print("epoch: ", epoch, "; loss: ", last_loss)

            with open(log_file_path, "a") as myfile:
                myfile.write("\n epoch: {}, loss: {}".format(epoch, last_loss))
                myfile.write("\n mean loss: {}, flips: {}, inv flips: {}, pos diff: {}, inv diff: {}, "
                             "incorrect preds: {}, variance: {}, mask density: {}".format(
                                 loss_ep / n, flips, inv_flips, pos_diff / n, inv_diff / n,
                                 incorrect_preds, ep_variance / n, avg_mask_density / n))
                if epoch % 10 == 0:
                    myfile.write("\n topk: {}, topk inv flips: {}, topk inv diff: {}".format(
                        self.args.topk, topk_inv_flips, topk_inv_diff / n))

                # plot cmap for graphs' node features
                io_utils.plot_cmap_tb(self.writer, "tab20", 20, "tab20_cmap")

                print(stats)

                myfile.write("\n New size_cf : {}".format(self.args.size_c))

                if epoch % 10 == 0:
                    myfile.write("\n explainer params sum: {}, model params sum: {}".format(
                        self._params_sum(explainer), self._params_sum(self.model)))

                    f_path = self.args.prefix + "explainer_" + self.args.bmname + "_pgeboundary.pth.tar"
                    save_path = os.path.join(log_path, f_path)
                    torch.save(explainer.state_dict(), save_path)
                    myfile.write("\n ckpt saved at {}".format(save_path))
                if epoch % 100 == 0:
                    f_path = (self.args.prefix + "explainer_" + self.args.bmname + "_pgeboundary_ep_"
                              + str(epoch) + ".pth.tar")
                    save_path = os.path.join(log_path, f_path)
                    torch.save(explainer.state_dict(), save_path)
                    myfile.write("\n ckpt saved at {}".format(save_path))

        if test_graph_indices is not None:
            print("EVALUATING")
            self.eval_graphs(args, test_graph_indices, explainer)


class ExplainModule(nn.Module):
    """Edge-mask generator used by the RCExplainer (no-LDB) explainer.

    An MLP scores every ordered node pair ``(i, j)`` from the concatenation of
    the two node embeddings. The scores are turned into a mask via a concrete
    (relaxed Bernoulli) sample. The mask is symmetrised and applied to the
    adjacency matrix. The wrapped ``model`` is then evaluated on the masked
    adjacency (factual branch) and/or on the complementary "inverse" adjacency
    (counterfactual branch).
    """

    def __init__(
        self,
        model,
        num_nodes,
        emb_dims,
        device,
        args
    ):
        """
        Args:
            model: the GNN being explained; called as
                ``model(x, adj, batch_num_nodes=...)`` and expected to return
                ``(logits, graph_embedding)``.
            num_nodes: size N of the (padded) adjacency; the mask is N x N.
            emb_dims: per-endpoint embedding widths; their sum is the MLP input
                width (so each endpoint embedding has width ``sum/2`` when the
                two widths are equal — TODO confirm against callers).
            device: torch device (or device string) for all tensors.
            args: namespace read for ``node_mask``, ``boundary_c``,
                ``inverse_boundary_c``, ``size_c`` and ``ent_c``.
        """
        super(ExplainModule, self).__init__()
        self.device = device

        self.model = model.to(self.device)
        self.num_nodes = num_nodes

        input_dim = int(np.sum(emb_dims))

        self.elayers = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        ).to(self.device)

        # Flattened index of every ordered pair in row-major order:
        # position k corresponds to (row, col) = (k // N, k % N).
        # Kept as float32 attributes for compatibility with existing callers.
        node_ids = torch.arange(0, self.num_nodes).to(torch.float32)
        self.row = node_ids.repeat_interleave(self.num_nodes).to(self.device)
        self.col = node_ids.repeat(self.num_nodes).to(self.device)

        self.softmax = nn.Softmax(dim=-1)

        self.mask_act = 'sigmoid'
        self.args = args

        # Fallback loss weights, used only when the corresponding args.*_c
        # value is negative (see ``loss``).
        self.coeffs = {
            "size": 0.012,#mutag
            "feat_size": 0.0,
            "ent": 0.0,
            "feat_ent": 0.1,
            "grad": 0,
            "lap": 1.0,
            "weight_decay": 0,
            "sample_bias": 0
        }

    def _masked_adj(self, mask, adj):
        """Apply a symmetrised ``mask`` to a scipy COO adjacency.

        Args:
            mask: N x N tensor of edge weights.
            adj: scipy ``coo_matrix`` (reads ``.row``, ``.col``, ``.data``,
                ``.shape``).

        Returns:
            Dense float32 tensor ``adj * (mask + mask.T) / 2`` with its diagonal
            zeroed. The dense adjacency is also stored in ``self.adj``.
        """
        mask = mask.to(self.device)
        sym_mask = (mask + mask.T) / 2

        # Build int64 indices directly; going through float32 (torch.Tensor)
        # silently loses precision for node ids above 2**24.
        indices = torch.as_tensor(np.vstack((adj.row, adj.col)), dtype=torch.int64)
        sparseadj = torch.sparse_coo_tensor(
            indices=indices,
            values=torch.as_tensor(adj.data),
            size=adj.shape
        )
        # coalesce() sums duplicate (row, col) entries before densifying.
        adj = sparseadj.coalesce().to_dense().to(torch.float32).to(self.device)
        self.adj = adj

        masked_adj = torch.mul(adj, sym_mask)

        # Remove self-loops from the masked adjacency.
        num_nodes = adj.shape[0]
        diag_mask = 1.0 - torch.eye(num_nodes, dtype=torch.float32, device=self.device)
        return torch.mul(masked_adj, diag_mask)

    def mask_density(self, adj):
        """Return sum(self.masked_adj) / sum(adj) (requires a prior ``forward``)."""
        mask_sum = torch.sum(self.masked_adj).cpu()
        adj_sum = torch.sum(adj)
        return mask_sum / adj_sum

    def concrete_sample(self, log_alpha, beta=1.0, training=True):
        """Sample gates from a binary concrete distribution.

        Args:
            log_alpha: tensor of gate logits.
            beta: temperature; smaller values give sharper (more binary) gates.
            training: when falsy, return the deterministic ``sigmoid(log_alpha)``.

        Returns:
            Tensor of gate values in [0, 1] with the shape of ``log_alpha``.
        """
        if not training:
            return torch.sigmoid(log_alpha)

        # Uniform noise in [bias, 1 - bias]. The bias was previously added a
        # second time, which shifted the range to [2*bias, 1] and let the
        # noise reach 1.0 (log(1 - 1) = -inf). The noise is drawn on the CPU
        # generator, as before, so seeded runs are reproducible.
        bias = self.coeffs['sample_bias']
        random_noise = torch.empty(log_alpha.shape, dtype=torch.float32, device='cpu')
        random_noise = random_noise.uniform_(bias, 1.0 - bias).to(self.device)
        gate_inputs = (torch.log(random_noise) - torch.log(1.0 - random_noise) + log_alpha) / beta
        return torch.sigmoid(gate_inputs)

    def forward(self, inputs, node_idx=None, training=None):
        """Compute the edge mask and run the model on masked/inverse graphs.

        Args:
            inputs: tuple ``(x, embed, adj, tmp, label, sub_nodes)``:
                node features, node embeddings (tensor or array indexable by
                node id), dense N x N adjacency, concrete-sample temperature,
                target label, and the real node count (or None) forwarded to
                the model as ``batch_num_nodes``.
            node_idx: unused; kept for interface compatibility.
            training: when truthy, gates are sampled stochastically.

        Returns:
            ``(res, masked_adj, g_embed, inv_embed, inv_res)`` where ``res`` /
            ``g_embed`` are None unless ``args.boundary_c > 0`` and
            ``inv_res`` / ``inv_embed`` are None unless
            ``args.inverse_boundary_c > 0``.
        """
        x, embed, adj, tmp, label, sub_nodes = inputs
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        x = x.to(self.device)
        if not isinstance(adj, torch.Tensor):
            adj = torch.tensor(adj)
        adj = adj.to(self.device)
        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label)
        self.label = label
        self.tmp = tmp

        # Per-pair features: [embed[i], embed[j]] for every ordered pair (i, j).
        row = self.row.long()
        col = self.col.long()
        if isinstance(embed, torch.Tensor):
            f1 = embed[row]
            f2 = embed[col]
        else:
            # Index array-like embeddings with CPU numpy indices so this also
            # works when self.device is a GPU.
            embed_np = np.asarray(embed)
            f1 = torch.Tensor(embed_np[row.cpu().numpy()]).to(self.device)
            f2 = torch.Tensor(embed_np[col.cpu().numpy()]).to(self.device)

        h = torch.cat([f1, f2], dim=-1).to(self.device)
        h = self.elayers(h)

        # Raw per-pair logits (one per ordered pair, row-major).
        self.values = torch.reshape(h, [-1])

        values = self.concrete_sample(self.values, beta=tmp, training=training)

        # row/col enumerate every (i, j) exactly once in row-major order, so
        # scattering `values` into an N x N matrix is exactly a reshape.
        sym_mask = values.reshape(self.num_nodes, self.num_nodes).to(torch.float32)

        # Unsymmetrised mask, used by the size/entropy terms in ``loss``.
        self.mask = sym_mask

        sym_mask = (sym_mask + sym_mask.T) / 2

        masked_adj = torch.mul(adj, sym_mask)
        self.masked_adj = masked_adj

        # Counterfactual graph: keep the edges the mask removes.
        inverse_masked_adj = torch.mul(adj, 1.0 - sym_mask)
        self.inverse_masked_adj = inverse_masked_adj

        if self.args.node_mask:
            # Per-node gate: strongest remaining inverse edge weight of each row.
            inverse_node_mask = torch.max(inverse_masked_adj, dim=1)[0]
            self.inverse_node_mask = inverse_node_mask

        x = torch.unsqueeze(x.detach().requires_grad_(True), 0).to(torch.float32)
        adj = torch.unsqueeze(masked_adj, 0).to(torch.float32)
        sub_num_nodes_l = [sub_nodes.cpu().numpy()] if sub_nodes is not None else None

        res = None
        g_embed = None
        inv_res = None
        inv_embed = None

        if self.args.boundary_c > 0.0:
            # graph mode
            output, g_embed = self.model(x, adj, batch_num_nodes=sub_num_nodes_l)
            res = self.softmax(output)

        if self.args.inverse_boundary_c > 0.0:
            if self.args.node_mask:
                # Broadcast the per-node gate over feature columns and scale x.
                node_gate = inverse_node_mask.unsqueeze(1).expand(x.shape[1], x.shape[2])
                x = x * node_gate.unsqueeze(0)

            inv_adj = torch.unsqueeze(inverse_masked_adj, 0).to(torch.float32)

            # graph mode
            inv_output, inv_embed = self.model(x, inv_adj, batch_num_nodes=sub_num_nodes_l)
            inv_res = self.softmax(inv_output)

        return res, masked_adj, g_embed, inv_embed, inv_res

    def loss(self, pred, pred_label, inv_pred=None, inv_pred_label=None):
        """Combine factual, counterfactual, size and entropy losses.

        Args:
            pred: softmax output on the masked graph, or None when the factual
                branch was skipped (``args.boundary_c <= 0``).
            pred_label: unused; the target class is ``self.label`` from
                ``forward``. Kept for interface compatibility.
            inv_pred: softmax output on the inverse-masked graph, or None.
            inv_pred_label: unused; kept for interface compatibility.

        Returns:
            ``(loss, inv_pred_loss)``: the total loss tensor and the
            counterfactual term as a Python float.
        """
        target = self.label

        if pred is not None:
            logit = pred[0][target]
            pred_loss = -torch.log(logit) * self.args.boundary_c
        else:
            pred_loss = torch.zeros((), device=self.device)

        if inv_pred is not None:
            inv_logit = inv_pred[0][target]
            # log(p) <= 0; subtracting eps keeps the denominator strictly
            # negative. Adding eps would cross zero as p -> 1 and flip the
            # loss to a huge negative value.
            inv_pred_loss = -1 * self.args.inverse_boundary_c / (torch.log(inv_logit) - 1e-6)
        else:
            inv_pred_loss = torch.zeros((), device=self.device)

        mask = self.mask
        if self.args.size_c > -0.001:
            size_loss = self.args.size_c * torch.sum(mask)
        else:
            size_loss = self.coeffs["size"] * torch.sum(mask)

        # entropy (mask squeezed into [0.005, 0.995] to keep log finite)
        mask = mask * 0.99 + 0.005
        mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
        if self.args.ent_c > -0.001:
            mask_ent_loss = self.args.ent_c * torch.mean(mask_ent)
        else:
            mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)

        loss = pred_loss + inv_pred_loss + size_loss + mask_ent_loss
        return loss, inv_pred_loss.item()

